package com.lw.leetcode.tree.b;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Created with IntelliJ IDEA.
 * 2477. Minimum Fuel Cost to Report to the Capital
 *
 * <p>The road network is a tree of {@code roads.length + 1} cities numbered {@code 0..n-1},
 * with the capital at city 0. Every city sends one representative to the capital; cars hold
 * {@code seats} people and each road traversal costs one liter. Representatives may share cars,
 * so a road leading out of a subtree of {@code k} cities must be crossed by
 * {@code ceil(k / seats)} cars, and the answer is the sum of that over all non-capital cities.
 *
 * @author liw
 * @version 1.0
 * @date 2022/11/22 13:31
 */
public class MinimumFuelCost {

    /** Runs the sample cases and prints expected vs. actual results. */
    public static void main(String[] args) {
        MinimumFuelCost test = new MinimumFuelCost();

        int[][][] roadSets = {
                {{0, 1}, {0, 2}, {0, 3}},
                {{3, 1}, {3, 2}, {1, 0}, {0, 4}, {0, 5}, {4, 6}},
                {{2, 1}, {3, 2}, {1, 0}},
                {},
                {{0, 1}},
                {{0, 1}, {0, 2}, {1, 3}, {1, 4}},
        };
        int[] seatCounts = {1, 2, 2, 1, 1, 5};
        long[] expected = {3, 7, 4, 0, 1, 4};

        // The same instance is reused on purpose: the solver keeps no state between calls.
        for (int i = 0; i < roadSets.length; i++) {
            long actual = test.minimumFuelCost(roadSets[i], seatCounts[i]);
            System.out.println("expected " + expected[i] + ", got " + actual);
        }
    }

    /**
     * Computes the minimum liters of fuel needed for every representative to reach city 0.
     *
     * <p>Runs in O(n) time and memory with an explicit stack, so deep (path-shaped) trees do not
     * overflow the call stack. The method keeps no state between calls.
     *
     * @param roads undirected edges {@code {a, b}} forming a tree over cities {@code 0..roads.length}
     * @param seats number of seats per car; must be at least 1
     * @return the total fuel; 0 when there are no roads (only the capital exists)
     * @throws IllegalArgumentException if {@code seats < 1} (with at least one road) or a city id is
     *     outside {@code [0, roads.length]}
     */
    public long minimumFuelCost(int[][] roads, int seats) {
        if (roads.length == 0) {
            return 0;
        }
        if (seats < 1) {
            throw new IllegalArgumentException("seats must be positive: " + seats);
        }
        int n = roads.length + 1;

        // Adjacency lists are local so repeated calls on one instance cannot see stale edges.
        List<List<Integer>> graph = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            graph.add(new ArrayList<>());
        }
        for (int[] road : roads) {
            int a = road[0];
            int b = road[1];
            requireCity(a, n);
            requireCity(b, n);
            graph.get(a).add(b);
            graph.get(b).add(a);
        }

        // Iterative DFS from the capital: record a pre-order and each city's parent.
        // Cities are marked when pushed, so each is pushed at most once and n slots suffice.
        int[] parent = new int[n];
        int[] order = new int[n];
        boolean[] seen = new boolean[n];
        int[] stack = new int[n];
        int top = 0;
        int visited = 0;
        stack[top++] = 0;
        seen[0] = true;
        parent[0] = -1;
        while (top > 0) {
            int city = stack[--top];
            order[visited++] = city;
            for (int next : graph.get(city)) {
                if (!seen[next]) {
                    seen[next] = true;
                    parent[next] = city;
                    stack[top++] = next;
                }
            }
        }

        // Walk the pre-order backwards so every child is finished before its parent.
        // order[0] is the capital, which needs no car, hence i > 0.
        int[] people = new int[n];
        long fuel = 0;
        for (int i = visited - 1; i > 0; i--) {
            int city = order[i];
            people[city]++; // this city's own representative
            // ceil(people / seats) without risking int overflow in (people + seats - 1).
            fuel += (people[city] - 1) / seats + 1;
            people[parent[city]] += people[city];
        }
        return fuel;
    }

    /** Rejects city ids outside {@code [0, n)}. */
    private static void requireCity(int city, int n) {
        if (city < 0 || city >= n) {
            throw new IllegalArgumentException("city " + city + " outside [0, " + n + ")");
        }
    }

}
